import matplotlib.pyplot as plt
import numpy as np
from faster_caching import uncached_tokens_to_latency
import pandas as pd

# Name of the current experiment run. create_cache_policy_plots writes its
# CSV under ./results/<NAME_RUN>/.
# NOTE: `global` rebinds this module's attribute only. Other modules (e.g.
# parallel_caching.py) see updates only if they read `<this_module>.NAME_RUN`;
# a `from <this_module> import NAME_RUN` copy will stay stale.
NAME_RUN = ""


def set_name_run(name):
    """Set the module-level run name used to build result output paths.

    Parameters:
        name: New value for ``NAME_RUN``.
    """
    global NAME_RUN
    NAME_RUN = name
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from io import StringIO

def plot_inverted_stacked_bars_from_csv_string(csv_file, benchmark_labels=None, output_file='policy_improvements_stacked_bar.png', colors=None, title=None):
    """
    Plot a stacked bar chart of inverted (1 / value) ratios read from a CSV.

    Each CSV row is one benchmark (x position) and each column one policy.
    Every cell is inverted. The policies are then stacked in column order:
    - The first segment is the first column's inverted value.
    - Each later segment is the difference from the previous column.
    So the top of the k-th segment equals the k-th column's inverted value.

    Parameters:
        csv_file (str or file-like): Anything ``pd.read_csv`` accepts. Every
            cell must be convertible to float.
        benchmark_labels (sequence of numbers, optional): One capacity per
            row, rendered as "<value / 1000>K" tick labels. When omitted,
            the ticks are labelled "Benchmark 1", "Benchmark 2", ...
        output_file (str): File name for saving the plot.
        colors (dict, optional): Maps each policy (column name) to a color.
            Policies missing from the dict are drawn in light grey.
        title (str, optional): Text drawn centred at 80% of the figure height.
    """
    # astype(float) replaces the deprecated DataFrame.applymap(float).
    df = pd.read_csv(csv_file).astype(float)

    # Invert the values (e.g., 1.0 / latency)
    inverted_df = 1.0 / df

    # Define default color scheme if not provided
    if colors is None:
        colors = {
            'tail_belady': '#000000',
            'tail_lru_none': '#7BADED',
            'tail_lru_end': '#3D73D2',
            'tail_lru_perfect': '#1B325E',
            'vanilla_belady': '#9467bd',
            'vanilla_lru': '#8c564b',
            'thre_lru': '#FFD700'
        }

    def pretty_print_capacity(capacity):
        return f"{capacity / 1000}K"

    def pretty_print_policy(policy):
        return policy.replace('_', ' ').replace('lru', 'LRU').replace('tail ', 'T-')

    # Set up x-axis
    num_benchmarks = len(df)
    if benchmark_labels is not None and len(benchmark_labels) > 0:
        benchmarks = [pretty_print_capacity(capacity) for capacity in benchmark_labels]
    else:
        # Previously this path crashed iterating None; fall back to generic names.
        benchmarks = [f"Benchmark {i + 1}" for i in range(num_benchmarks)]
    x = np.arange(num_benchmarks)
    bar_width = 0.6

    fig, ax = plt.subplots(figsize=(5, 3))

    # Each segment shows the improvement of a column over the previous one.
    cols = list(inverted_df.columns)
    prev = np.zeros(num_benchmarks)
    for i, col in enumerate(cols):
        values = inverted_df[col].to_numpy()
        if i == 0:
            delta = values  # First bar is just its value
        else:
            delta = values - inverted_df[cols[i - 1]].to_numpy()
        ax.bar(x, delta, bar_width, bottom=prev,
               label=pretty_print_policy(col), color=colors.get(col, '#cccccc'))
        # Rebind rather than mutate in place so no array handed to ax.bar changes later.
        prev = prev + delta

    ax.set_xlabel('Capacity', fontsize=8)
    ax.set_ylabel('Relative Improvement', fontsize=8)
    ax.set_xticks(x)
    ax.set_ylim(0.8, 2.0)
    # Unrotated labels are centred under their bars ('right' only suits rotated labels).
    ax.set_xticklabels(benchmarks, rotation=0, ha='center')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncols=3)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()

    if title:
        # Figure-level text, centred horizontally at 80% of the figure height.
        fig.text(0.5, 0.8, title, ha='center', va='top', fontsize=10)

    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _results_to_latency(uncached_tokens_results, percentiles, C_values):
    """Return a copy of ``uncached_tokens_results`` with values mapped to latency.

    For every percentile in ``percentiles``, the value at each position
    ``j < len(C_values)`` is passed through ``uncached_tokens_to_latency``.
    Keys not in ``percentiles`` are carried over untouched, and the caller's
    dict and sequences are never modified.

    Conversion is best-effort. A missing percentile, a too-short sequence, or
    a converter error is printed, and the raw value is kept.
    """
    converted = dict(uncached_tokens_results)
    for p in percentiles:
        if p not in converted:
            print(f"Error: percentile {p} not found in results; skipping latency conversion")
            continue
        # Convert into a fresh list. This leaves the caller's data untouched.
        # It also keeps whatever type uncached_tokens_to_latency returns, rather
        # than casting it back to an input ndarray's (possibly integer) dtype.
        values = list(converted[p])
        for j, C in enumerate(C_values):
            try:
                values[j] = uncached_tokens_to_latency(values[j])
            except Exception as e:
                print(f"Error converting C={C}, p={p}: {e}")
        converted[p] = values
    return converted


def _save_latency_results(policy_results, C_values):
    """Write every policy's results to ./results/<NAME_RUN>/latency_results_all.csv.

    ``policy_results`` maps a policy key (e.g. 'tail_belady') to a
    ``{percentile: values}`` dict. Each policy's dict becomes a DataFrame
    with one column per percentile, indexed by ``C_values``, plus a trailing
    'policy' row naming the policy. The frames are concatenated side by side.
    Returns the concatenated DataFrame.
    """
    import os

    frames = []
    for policy, results in policy_results.items():
        frame = pd.DataFrame(results)
        frame.index = C_values
        frame.loc['policy'] = [policy] * frame.shape[1]
        frames.append(frame)
    results_df = pd.concat(frames, axis=1)

    out_dir = "./results/{}".format(NAME_RUN)
    os.makedirs(out_dir, exist_ok=True)
    results_df.to_csv("{}/latency_results_all.csv".format(out_dir))
    return results_df


def create_cache_policy_plots(
    C_values, 
    belady_results, 
    lru_results, 
    lru_end_results, 
    lru_perfect_results, 
    vanilla_lru_results, 
    thre_lru_results,
    percentiles, 
    title=None, 
    filename=None, 
    show_plot=False,
    log_scale=False,
    colors=None,
    markers=None,
    figsize=(18, 2.8),
    dpi=300
):
    """
    Create plots comparing different cache policies across multiple percentiles.

    Every policy's uncached-token results are converted to latency with
    ``uncached_tokens_to_latency``. This works on copies, so the input dicts
    are not modified. The converted results are written to
    ./results/<NAME_RUN>/latency_results_all.csv, and the function then draws
    one subplot per percentile in a single row.

    Parameters:
    -----------
    C_values : list or array
        Cache capacity values used in the evaluation
    *_results : dict
        Dictionaries with percentiles as keys and per-capacity lists of
        uncached-token values (aligned with ``C_values``) as values
    percentiles : list
        List of percentiles to plot
    title : str, optional
        Currently unused; kept for backward compatibility
    filename : str, optional
        If provided, save the plot to this filename
    show_plot : bool, default=False
        Whether to display the plot
    log_scale : bool or tuple, default=False
        A tuple ``(log_x, log_y)`` selects each axis. A truthy non-tuple
        value puts only the x axis on a log scale.
    colors : dict, optional
        Dictionary mapping policy names to colors
    markers : dict, optional
        Dictionary mapping policy names to marker styles
    figsize : tuple, default=(18, 2.8)
        Figure size in inches
    dpi : int, default=300
        Resolution for saved figure

    Returns:
    --------
    fig, axes : tuple
        Matplotlib figure and axes. ``axes`` is a flattened array when
        there are several percentiles, else a one-element list.
    """
    if colors is None:
        colors = {
            'tail_belady': '#000000',       # black
            'tail_lru_none': '#7BADED',     # light blue
            'tail_lru_end': '#3D73D2',      # medium blue
            'tail_lru_perfect': '#1B325E',  # dark blue
            'vanilla_lru': '#8c564b',       # brown
            'thre_lru': '#FFD700'           # gold
        }

    if markers is None:
        markers = {
            'tail_belady': 'o',       # circle
            'tail_lru_none': 's',     # square
            'tail_lru_end': 'D',      # diamond
            'tail_lru_perfect': 'P',  # filled plus
            'vanilla_lru': 'v',       # triangle down
            'thre_lru': 'x'           # x
        }

    # Convert uncached-token counts to latency (on copies of the inputs).
    belady_results = _results_to_latency(belady_results, percentiles, C_values)
    lru_results = _results_to_latency(lru_results, percentiles, C_values)
    lru_end_results = _results_to_latency(lru_end_results, percentiles, C_values)
    lru_perfect_results = _results_to_latency(lru_perfect_results, percentiles, C_values)
    vanilla_lru_results = _results_to_latency(vanilla_lru_results, percentiles, C_values)
    thre_lru_results = _results_to_latency(thre_lru_results, percentiles, C_values)

    results_df = _save_latency_results({
        'tail_belady': belady_results,
        'tail_lru_none': lru_results,
        'tail_lru_end': lru_end_results,
        'tail_lru_perfect': lru_perfect_results,
        'vanilla_lru': vanilla_lru_results,
        'thre_lru': thre_lru_results,
    }, C_values)
    print("results_df: {}".format(results_df))

    # (policy key, results, line style, legend label), in legend order.
    series = [
        ('tail_belady', belady_results, '-.', 'Tail-Optimized Belady'),
        ('tail_lru_none', lru_results, '-', 'Tail-Optimized LRU'),
        ('tail_lru_end', lru_end_results, '-', 'End-Aware T-LRU'),
        ('tail_lru_perfect', lru_perfect_results, '-', 'Length-Aware T-LRU'),
        ('vanilla_lru', vanilla_lru_results, '--', 'LRU'),
        ('thre_lru', thre_lru_results, '--', 'Threshold LRU'),
    ]

    n_plots = len(percentiles)
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)
    # A single subplot comes back as a bare Axes; normalise to an iterable.
    axes = axes.flatten() if n_plots > 1 else [axes]

    lines, labels = [], []
    for i, (ax, p) in enumerate(zip(axes, percentiles)):
        for key, results, linestyle, label in series:
            line, = ax.plot(C_values, results[p], linestyle,
                            color=colors[key], marker=markers[key], label=label)
            # Collect legend handles from the first subplot only.
            if i == 0:
                lines.append(line)
                labels.append(label)

        # Only the first subplot carries the y-axis label.
        ax.set_ylabel('Latency (s)' if i == 0 else '')
        ax.set_xlabel('Cache Capacity (C)')
        ax.set_title(f'{p}th Percentile')

        if log_scale:
            if isinstance(log_scale, tuple):
                if log_scale[0]:
                    ax.set_xscale('log')
                if log_scale[1]:
                    ax.set_yscale('log')
            else:
                ax.set_xscale('log')

        ax.grid(True, alpha=0.3)

    # One shared legend above all subplots.
    fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 1.15),
               fancybox=True, shadow=True, ncol=9, fontsize=12)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.1, top=0.92)

    if filename:
        print("Saving plot to {}".format(filename))
        fig.savefig(filename, dpi=dpi, bbox_inches='tight')

    if show_plot:
        plt.show()

    return fig, axes

# Additional specialized plotting function for comparing specific aspects
def create_comparative_analysis_plots(
    C_values,
    results_dict,
    percentiles,
    comparisons=None,
    figsize=(18, 10),
    filename=None,
    show_plot=False,
    dpi=300
):
    """
    Create specialized comparative plots for deeper analysis of cache policies.

    Parameters:
    -----------
    C_values : list or array
        Cache capacity values used in the evaluation
    results_dict : dict or tuple
        Either a dict ``{policy_name: {percentile: values}}``, or a 6-tuple
        ordered as (tail_belady, tail_lru_none, tail_lru_end,
        tail_lru_perfect, vanilla_lru, thre_lru), each a
        ``{percentile: values}`` dict
    percentiles : list
        List of percentiles to plot (one subplot column each)
    comparisons : list of dicts, optional
        One subplot row per comparison, each of the form
        ``{'title': str, 'type': 'diff' | 'ratio',
        'policies': [(policy1, policy2[, label]), ...]}``.
        For 'diff' the metric is policy2 - policy1. For 'ratio' it is
        policy2 / policy1; positions where policy1 is 0 become NaN.
    figsize : tuple, default=(18, 10)
        Figure size in inches
    filename : str, optional
        If provided, save the plot to this filename
    show_plot : bool, default=False
        Whether to display the plot
    dpi : int, default=300
        Resolution for saved figure

    Returns:
    --------
    fig, axes : tuple
        Matplotlib figure and a 2-D array of axes (comparisons x percentiles)
    """
    if comparisons is None:
        comparisons = [
            {
                'title': 'Performance of Different Predictors',
                'policies': [
                    ('tail_lru_perfect', 'tail_lru_none', 'Perfect vs None'),
                    ('tail_lru_end', 'tail_lru_none', 'End vs None'),
                    ('tail_belady', 'tail_lru_none', 't-Belady vs T-LRU')
                ],
                'type': 'ratio'
            }
        ]

    if isinstance(results_dict, tuple) and len(results_dict) == 6:
        # Positional tuple of per-policy results, in this fixed order.
        policy_keys = ('tail_belady', 'tail_lru_none', 'tail_lru_end',
                       'tail_lru_perfect', 'vanilla_lru', 'thre_lru')
        organized_results = dict(zip(policy_keys, results_dict))
    else:
        organized_results = results_dict

    # squeeze=False always yields a 2-D (rows x cols) array of axes.
    fig, axes = plt.subplots(len(comparisons), len(percentiles),
                             figsize=figsize, squeeze=False)

    for i, comparison in enumerate(comparisons):
        is_diff = comparison['type'] == 'diff'
        # Decided per comparison so it is defined even with no policy pairs.
        ylabel = 'Absolute Difference in Uncached Tokens' if is_diff else 'Ratio of Uncached Tokens'

        for j, p in enumerate(percentiles):
            ax = axes[i, j]

            for policy_pair in comparison['policies']:
                if len(policy_pair) == 3:
                    policy1, policy2, label = policy_pair
                else:
                    policy1, policy2 = policy_pair
                    label = f"{policy1} vs {policy2}"

                data1 = np.asarray(organized_results[policy1][p], dtype=float)
                data2 = np.asarray(organized_results[policy2][p], dtype=float)

                if is_diff:
                    # policy2 - policy1: positive when policy1 is smaller.
                    metric = data2 - data1
                else:
                    # policy2 / policy1, with NaN (a gap in the line) where policy1 == 0.
                    metric = np.divide(
                        data2, data1,
                        out=np.full(np.broadcast(data2, data1).shape, np.nan),
                        where=data1 != 0,
                    )

                ax.plot(C_values, metric, 'o-', label=label)

            ax.set_xlabel('Cache Capacity (C)')
            ax.set_ylabel(ylabel)

            if len(percentiles) > 1:
                ax.set_title(f"{comparison['title']} - {p}th Percentile")
            else:
                ax.set_title(f"{comparison['title']}")

            # Reference line: ratio == 1 or diff == 0 means "no difference".
            ax.axhline(y=0 if is_diff else 1, color='gray', linestyle='--', alpha=0.5)

            ax.grid(True, alpha=0.3)
            ax.legend()

    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=dpi, bbox_inches='tight')

    if show_plot:
        plt.show()

    return fig, axes

def create_xi_comparison_plots(
    C_values,
    results_by_xi,
    percentile,
    filename=None,
    title=None,
    show_plot=False,
    log_scale=(False, True),
    figsize=(10, 6),
    dpi=300
):
    """
    Overlay every cache policy's curve for each xi value at one percentile.

    Each policy keeps a fixed color and marker. The xi values are taken in
    sorted order and told apart by line style. The 6th and later xi values
    all reuse the last style.

    Parameters:
    -----------
    C_values : list or array
        Cache capacity values (x axis)
    results_by_xi : dict
        ``{xi: {policy_name: {percentile: values}}}``
    percentile : int
        The percentile to plot (e.g., 50, 90, 95, 99)
    filename : str, optional
        If provided, save the plot to this filename
    title : str, optional
        Axes title
    show_plot : bool, default=False
        Whether to display the plot
    log_scale : tuple, default=(False, True)
        Whether to use log scale for the (x, y) axes
    figsize : tuple, default=(10, 6)
        Figure size in inches
    dpi : int, default=300
        Resolution for saved figure

    Returns:
    --------
    fig, ax : tuple
        Matplotlib figure and axes objects
    """
    palette = {
        'tail_belady': '#000000',       # black
        'tail_lru_none': '#7BADED',     # light blue
        'tail_lru_end': '#3D73D2',      # medium blue
        'tail_lru_perfect': '#1B325E',  # dark blue
        'vanilla_lru': '#8c564b',       # brown
        'thre_lru': '#FFD700'           # gold
    }
    marker_for = {
        'tail_belady': 'o',       # circle
        'tail_lru_none': 's',     # square
        'tail_lru_end': 'D',      # diamond
        'tail_lru_perfect': 'P',  # filled plus
        'vanilla_lru': 'v',       # triangle down
        'thre_lru': 'x'           # x
    }
    display_names = {
        'tail_lru_none': "Tail-Optimized LRU",
        'tail_lru_end': "End-Aware T-LRU",
        'tail_lru_perfect': "Length-Aware T-LRU",
        'vanilla_lru': "LRU",
        'tail_belady': "Tail-Optimized Belady",
        'thre_lru': "Threshold LRU",
    }
    # Line style by position of xi in sorted order; later xi reuse the last entry.
    xi_linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1))]
    policy_order = ('tail_belady', 'tail_lru_none', 'tail_lru_end',
                    'tail_lru_perfect', 'vanilla_lru', 'thre_lru')

    fig, ax = plt.subplots(figsize=figsize)

    xi_values = sorted(results_by_xi)
    for policy in policy_order:
        pretty = display_names.get(policy, policy.replace('_', ' '))
        for idx, xi in enumerate(xi_values):
            ax.plot(
                C_values,
                results_by_xi[xi][policy][percentile],
                marker=marker_for[policy],
                color=palette[policy],
                linestyle=xi_linestyles[min(idx, len(xi_linestyles) - 1)],
                linewidth=2,
                alpha=0.7,
                label=f"{pretty} (xi={xi})"
            )

    ax.set_xlabel('Cache Capacity', fontsize=12)
    ax.set_ylabel(f'P{percentile} Uncached Tokens', fontsize=12)
    if title:
        ax.set_title(title, fontsize=14)

    use_log_x, use_log_y = log_scale[0], log_scale[1]
    if use_log_x:
        ax.set_xscale('log')
    if use_log_y:
        ax.set_yscale('log')

    ax.grid(True, alpha=0.3)
    ax.legend(loc='best', fontsize=10)
    fig.tight_layout()

    if filename:
        fig.savefig(filename, dpi=dpi, bbox_inches='tight')
    if show_plot:
        plt.show()

    return fig, ax

# Example usage:

if __name__ == "__main__":
    # Capacities used as the x-axis tick labels (one per CSV row), shown as "1.0K", ...
    C_values = [1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000]
    run_tag = "p90_Wildchat_xi1000_Q200"
    plot_inverted_stacked_bars_from_csv_string(
        csv_file='tmp_p90.csv',
        benchmark_labels=C_values,
        output_file='tail_vs_vanilla_{}.pdf'.format(run_tag),
        title=run_tag,
    )

    # To redraw the per-percentile latency line plots from pickled results
    # (only unpickle files you produced yourself):
    # import pickle
    # with open("./results/parallel_caching.pkl", "rb") as f:
    #     results_dict = pickle.load(f)
    # percentiles = [50, 90, 95, 99]
    # create_cache_policy_plots(
    #     C_values,
    #     results_dict['tail_belady'],
    #     results_dict['tail_lru_none'],
    #     results_dict['tail_lru_end'],
    #     results_dict['tail_lru_perfect'],
    #     results_dict['vanilla_lru'],
    #     results_dict['thre_lru'],
    #     percentiles,
    #     title='Tail-Optimized Caching vs. Vanilla Caching',
    #     filename='tail_vs_vanilla.pdf',
    #     show_plot=False,
    # )
    # create_comparative_analysis_plots(C_values, results_dict, percentiles)